{ "cells": [ { "cell_type": "markdown", "id": "a1b2c3d4", "metadata": {}, "source": [ "# Tutorial 01 — Quick Start\n\nIn this tutorial, we train an audio classifier on the [GTZAN](https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification) dataset, which contains audio clips from 10 music genres.\n\nBy the end of this tutorial, we will have:\n- Loaded an audio dataset from a directory\n- Built and trained an `AudioClassifier` with a pretrained backbone\n- Evaluated model accuracy on a held-out test set\n- Run single-file inference on unseen data\n\n---\n\n**Assumed dataset layout**\n\nDeepAudioX expects the dataset to follow this directory structure:\n\n```\ngtzan/\n├── train/\n│   ├── blues/\n│   │   ├── blues.00000.wav\n│   │   ├── blues.00001.wav\n│   │   └── ...\n│   ├── classical/\n│   ├── country/\n│   ├── disco/\n│   ├── hiphop/\n│   ├── jazz/\n│   ├── metal/\n│   ├── pop/\n│   ├── reggae/\n│   └── rock/\n└── test/\n    ├── blues/\n    └── ...\n```\n\nEach sub-folder name becomes a class label. See Tutorial 02 for an alternative loading strategy based on Python dictionaries." ] }, { "cell_type": "markdown", "id": "d4e5f6a7", "metadata": {}, "source": [ "## 1. Configuration\n", "\n", "Update the two directory paths below and choose your target device before running any other cell." ] }, { "cell_type": "code", "execution_count": 1, "id": "e5f6a7b8", "metadata": {}, "outputs": [], "source": [ "TRAIN_DIR = \"/data/gtzan/train\" # directory containing class sub-folders\n", "TEST_DIR = \"/data/gtzan/test\"\n", "SAMPLE_RATE = 32_000 # sampling rate (Hz) used when loading audio files\n", "CHECKPOINT = \"checkpoint.pt\" # path where the best model will be saved\n", "DEVICE = \"cuda\" # \"cuda\" | \"mps\" | \"cpu\"" ] }, { "cell_type": "markdown", "id": "f6a7b8c9", "metadata": {}, "source": [ "## 2. Loading the Dataset\n", "\n", "`get_class_mapping_from_dir` scans the top-level sub-folders of the training directory and automatically builds a `{class_name: int}` mapping. The same mapping must be used by every dataset, by the model (via `len(class_mapping)`), and by every evaluation or inference call: it is the single source of truth for label ordering throughout the project." ] }, { "cell_type": "code", "execution_count": 2, "id": "a7b8c9d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 classes detected: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']\n" ] } ], "source": [ "from deepaudiox import get_class_mapping_from_dir\n", "\n", "class_mapping = get_class_mapping_from_dir(TRAIN_DIR)\n", "print(f\"{len(class_mapping)} classes detected: {list(class_mapping.keys())}\")" ] }, { "cell_type": "markdown", "id": "587c86b4", "metadata": {}, "source": [ "DeepAudioX lets you easily create a PyTorch `Dataset` for audio classification with the `audio_classification_dataset_from_dir` function."
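, "\n", "\n", "For intuition, here is a minimal sketch of the indexing a directory-based dataset performs. It numbers the class sub-folders, then lists one `(path, class_id, class_name)` entry per audio file. This is not DeepAudioX's implementation. The helpers `build_class_mapping` and `index_files` and the `AUDIO_EXTENSIONS` set are hypothetical. Numbering classes in alphabetical order is an assumption, though it is consistent with the class order printed above.\n", "
```python
import tempfile
from pathlib import Path

AUDIO_EXTENSIONS = {'.wav', '.mp3'}  # assumption: suffixes treated as audio files


def build_class_mapping(root_dir):
    # Hypothetical helper: sorted sub-folder names -> consecutive integer IDs
    class_dirs = sorted(p.name for p in Path(root_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(class_dirs)}


def index_files(root_dir, class_mapping):
    # Hypothetical helper: one (path, class_id, class_name) tuple per audio file
    items = []
    for class_name, class_id in class_mapping.items():
        for path in sorted((Path(root_dir) / class_name).iterdir()):
            if path.suffix.lower() in AUDIO_EXTENSIONS:
                items.append((str(path), class_id, class_name))
    return items


# Usage on a throwaway tree of empty placeholder files (no real audio is read)
with tempfile.TemporaryDirectory() as root:
    for genre, n_files in [('rock', 2), ('blues', 3)]:
        (Path(root) / genre).mkdir()
        for i in range(n_files):
            (Path(root) / genre / f'{genre}.{i:05d}.wav').touch()
    mapping = build_class_mapping(root)
    items = index_files(root, mapping)

print(mapping)                   # prints: {'blues': 0, 'rock': 1}
print(len(items), items[0][1:])  # prints: 5 (0, 'blues')
```
", "\n", "A full dataset would also load each file's waveform at the configured sample rate. The sketch stops at indexing."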
] }, { "cell_type": "code", "execution_count": 3, "id": "b8c9d0e1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train samples : 799\n", "Test samples : 195\n" ] } ], "source": [ "from deepaudiox import audio_classification_dataset_from_dir\n", "\n", "train_dataset = audio_classification_dataset_from_dir(\n", " root_dir=TRAIN_DIR,\n", " sample_rate=SAMPLE_RATE,\n", " class_mapping=class_mapping,\n", ")\n", "\n", "test_dataset = audio_classification_dataset_from_dir(\n", " root_dir=TEST_DIR,\n", " sample_rate=SAMPLE_RATE,\n", " class_mapping=class_mapping,\n", ")\n", "\n", "print(f\"Train samples : {len(train_dataset)}\")\n", "print(f\"Test samples : {len(test_dataset)}\")" ] }, { "cell_type": "markdown", "id": "2d66d10f", "metadata": {}, "source": [ "Each item returned by the dataset is a dictionary containing:\n", "\n", "```python\n", "{\n", " \"path\": str, # File path of the audio\n", " \"y_true\": int, # Integer class ID\n", " \"class_name\": str, # String class label\n", " \"segment_idx\": int, # Segment index (for segmented audio)\n", " \"feature\": np.ndarray # Audio waveform as a NumPy array\n", "}\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "id": "18e5c283", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "path : /data/gtzan/train/blues/blues.00056.wav\n", "y_true : 0\n", "class_name : blues\n", "segment_idx : 0\n", "feature : [-0.15826386 -0.26958048 -0.29068074 ... -0.37689695 -0.36377403\n", " -0.12435105]\n" ] } ], "source": [ "item = train_dataset[0]\n", "for key, value in item.items():\n", " print(key, \":\", value)" ] }, { "cell_type": "markdown", "id": "79b24754", "metadata": {}, "source": [ "## 3. Working with Segmented Audio" ] }, { "cell_type": "markdown", "id": "38b88f68", "metadata": {}, "source": [ "`audio_classification_dataset_from_dir` also accepts a `segment_duration` argument (in seconds) that sets the duration of each sample returned by `__getitem__`.\n", "\n", "In audio, it is common to work with short segments (e.g., under 10 seconds) to capture temporal changes within a track. When `segment_duration` is set, each file is split into segments. Each segment is treated as an independent sample and carries the class label of its source file. The `segment_idx` field of each item indicates which segment of the file the sample corresponds to. Below, we re-create both datasets with 3-second segments. GTZAN clips are 30 seconds long, so each file yields about ten segments and the sample counts grow roughly tenfold." ] }, { "cell_type": "code", "execution_count": 5, "id": "afea00d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train samples : 7981\n", "Test samples : 1950\n" ] } ], "source": [ "train_dataset = audio_classification_dataset_from_dir(\n", " root_dir=TRAIN_DIR,\n", " sample_rate=SAMPLE_RATE,\n", " class_mapping=class_mapping,\n", " segment_duration=3.0\n", ")\n", "\n", "test_dataset = audio_classification_dataset_from_dir(\n", " root_dir=TEST_DIR,\n", " sample_rate=SAMPLE_RATE,\n", " class_mapping=class_mapping,\n", " segment_duration=3.0\n", ")\n", "\n", "print(f\"Train samples : {len(train_dataset)}\")\n", "print(f\"Test samples : {len(test_dataset)}\")" ] }, { "cell_type": "markdown", "id": "c9d0e1f2", "metadata": {}, "source": [ "## 4. Building the Classifier\n", "\n", "Use the `AudioClassifier` class to build an audio classifier on top of a pretrained backbone. Besides the choice of backbone, the key parameters are `num_classes` and `sample_rate`. `num_classes` sets the size of the classifier head. `sample_rate` is used internally by the backbone to extract spectro-temporal features from the raw audio waveforms.\n", "\n", "Here we use a **PaSST** backbone pre-trained on AudioSet and freeze its weights, so only the lightweight classifier head is updated. This keeps training fast enough to run on a laptop GPU or even a CPU." ] }, { "cell_type": "code", "execution_count": 6, "id": "d0e1f2a3", "metadata": {}, "outputs": [], "source": [ "from deepaudiox import AudioClassifier\n", "\n", "model = AudioClassifier(\n", " num_classes=len(class_mapping),\n", " backbone=\"passt\",\n", " pretrained=True, # load pretrained weights for the backbone\n", " freeze_backbone=True, # only the classifier head is trained\n", " sample_rate=SAMPLE_RATE,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "49e7078e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AudioClassifierConstructor(\n", " (backbone_constructor): BackboneConstructor(\n", " (backbone): PaSST(\n", " (feature_extractor): AugmentMelSTFT(\n", " winsize=800, hopsize=320\n", " (freqm): FrequencyMasking()\n", " (timem): TimeMasking()\n", " )\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (blocks): Sequential(\n", " (0): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): 
GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, 
elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (6): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): 
Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (7): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (8): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (9): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): 
Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (10): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (11): Block(\n", " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=768, out_features=768, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (drop_path): Identity()\n", " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", " (act): GELU(approximate='none')\n", " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", " (drop): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", " 
(pre_logits): Identity()\n", " )\n", " (pooling): GAP()\n", " )\n", " (classifier): MLPHead(\n", " (model): Sequential(\n", " (0): Linear(in_features=768, out_features=10, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Display the model architecture\n", "\n", "model" ] }, { "cell_type": "markdown", "id": "e1f2a3b4", "metadata": {}, "source": [ "## 5. Training\n", "\n", "The `Trainer` class handles the complete training loop out of the box:\n", "\n", "| Feature | Default |\n", "|---|---|\n", "| Train / validation split | 80 / 20 (automatic) |\n", "| Optimizer | Adam, lr = 1e-3 |\n", "| LR scheduler | ReduceLROnPlateau |\n", "| Loss | Cross-entropy |\n", "| Early stopping | After `patience` epochs without validation-loss improvement |\n", "| Checkpointing | Best model saved to `path_to_checkpoint` |\n", "\n", "All defaults can be overridden. See Tutorial 04 for advanced configuration." ] }, { "cell_type": "code", "execution_count": 8, "id": "f2a3b4c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using GPU: NVIDIA GeForce RTX 4090\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Epoch 1/30]\n", "Epoch 1 | Train Loss: 1.7399 | Val. Loss: 1.1800 | Time: 19.64s \n", "[CHECKPOINTER] Validation loss decreased: (inf --> 1.180008), \u001b[92m(-nan%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 2/30]\n", "Epoch 2 | Train Loss: 1.1476 | Val. Loss: 0.8874 | Time: 18.99s \n", "[CHECKPOINTER] Validation loss decreased: (1.180008 --> 0.887417), \u001b[92m(-24.80%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 3/30]\n", "Epoch 3 | Train Loss: 0.9254 | Val.
Loss: 0.7591 | Time: 19.41s \n", "[CHECKPOINTER] Validation loss decreased: (0.887417 --> 0.759138), \u001b[92m(-14.46%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 4/30]\n", "Epoch 4 | Train Loss: 0.8158 | Val. Loss: 0.7005 | Time: 19.03s \n", "[CHECKPOINTER] Validation loss decreased: (0.759138 --> 0.700534), \u001b[92m(-7.72%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 5/30]\n", "Epoch 5 | Train Loss: 0.7439 | Val. Loss: 0.6588 | Time: 19.73s \n", "[CHECKPOINTER] Validation loss decreased: (0.700534 --> 0.658843), \u001b[92m(-5.95%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 6/30]\n", "Epoch 6 | Train Loss: 0.6974 | Val. Loss: 0.6386 | Time: 19.19s \n", "[CHECKPOINTER] Validation loss decreased: (0.658843 --> 0.638621), \u001b[92m(-3.07%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 7/30]\n", "Epoch 7 | Train Loss: 0.6623 | Val. Loss: 0.6237 | Time: 19.12s \n", "[CHECKPOINTER] Validation loss decreased: (0.638621 --> 0.623666), \u001b[92m(-2.34%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 8/30]\n", "Epoch 8 | Train Loss: 0.6299 | Val. Loss: 0.6037 | Time: 19.14s \n", "[CHECKPOINTER] Validation loss decreased: (0.623666 --> 0.603713), \u001b[92m(-3.20%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 9/30]\n", "Epoch 9 | Train Loss: 0.6011 | Val. Loss: 0.5913 | Time: 19.14s \n", "[CHECKPOINTER] Validation loss decreased: (0.603713 --> 0.591263), \u001b[92m(-2.06%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 10/30]\n", "Epoch 10 | Train Loss: 0.5858 | Val. 
Loss: 0.5845 | Time: 19.12s \n", "[CHECKPOINTER] Validation loss decreased: (0.591263 --> 0.584525), \u001b[92m(-1.14%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 11/30]\n", "Epoch 11 | Train Loss: 0.5679 | Val. Loss: 0.5834 | Time: 19.08s \n", "[CHECKPOINTER] Validation loss decreased: (0.584525 --> 0.583369), \u001b[92m(-0.20%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 12/30]\n", "Epoch 12 | Train Loss: 0.5546 | Val. Loss: 0.5725 | Time: 19.24s \n", "[CHECKPOINTER] Validation loss decreased: (0.583369 --> 0.572482), \u001b[92m(-1.87%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 13/30]\n", "Epoch 13 | Train Loss: 0.5458 | Val. Loss: 0.5639 | Time: 19.12s \n", "[CHECKPOINTER] Validation loss decreased: (0.572482 --> 0.563908), \u001b[92m(-1.50%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 14/30]\n", "Epoch 14 | Train Loss: 0.5314 | Val. Loss: 0.5577 | Time: 19.14s \n", "[CHECKPOINTER] Validation loss decreased: (0.563908 --> 0.557748), \u001b[92m(-1.09%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 15/30]\n", "Epoch 15 | Train Loss: 0.5154 | Val. Loss: 0.5573 | Time: 19.16s \n", "[CHECKPOINTER] Validation loss decreased: (0.557748 --> 0.557256), \u001b[92m(-0.09%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 16/30]\n", "Epoch 16 | Train Loss: 0.5136 | Val. Loss: 0.5517 | Time: 19.08s \n", "[CHECKPOINTER] Validation loss decreased: (0.557256 --> 0.551656), \u001b[92m(-1.01%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 17/30]\n", "Epoch 17 | Train Loss: 0.4960 | Val. Loss: 0.5537 | Time: 19.13s \n", "[Epoch 18/30]\n", "Epoch 18 | Train Loss: 0.4868 | Val. Loss: 0.5557 | Time: 18.84s \n", "[Epoch 19/30]\n", "Epoch 19 | Train Loss: 0.4890 | Val. 
Loss: 0.5368 | Time: 18.85s \n", "[CHECKPOINTER] Validation loss decreased: (0.551656 --> 0.536754), \u001b[92m(-2.70%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 20/30]\n", "Epoch 20 | Train Loss: 0.4828 | Val. Loss: 0.5426 | Time: 19.40s \n", "[Epoch 21/30]\n", "Epoch 21 | Train Loss: 0.4664 | Val. Loss: 0.5397 | Time: 18.88s \n", "[Epoch 22/30]\n", "Epoch 22 | Train Loss: 0.4576 | Val. Loss: 0.5293 | Time: 18.82s \n", "[CHECKPOINTER] Validation loss decreased: (0.536754 --> 0.529326), \u001b[92m(-1.38%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 23/30]\n", "Epoch 23 | Train Loss: 0.4554 | Val. Loss: 0.5226 | Time: 19.08s \n", "[CHECKPOINTER] Validation loss decreased: (0.529326 --> 0.522598), \u001b[92m(-1.27%)\u001b[0m.\n", "[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt\n", "[Epoch 24/30]\n", "Epoch 24 | Train Loss: 0.4541 | Val. Loss: 0.5298 | Time: 19.10s \n", "[Epoch 25/30]\n", "Epoch 25 | Train Loss: 0.4490 | Val. Loss: 0.5253 | Time: 18.86s \n", "[Epoch 26/30]\n", "Epoch 26 | Train Loss: 0.4373 | Val. Loss: 0.5240 | Time: 18.86s \n", "[Epoch 27/30]\n", "Epoch 27 | Train Loss: 0.4341 | Val. Loss: 0.5269 | Time: 18.94s \n", "[EARLY STOPPING] Elapsed epochs: 4 out of 5\n", "[Epoch 28/30]\n", "Epoch 28 | Train Loss: 0.4306 | Val. Loss: 0.5257 | Time: 18.84s \n", "[EARLY STOPPING] Elapsed epochs: 5 out of 5\n", "[EARLY STOPPING] Patience exceeded, early stoping ...\n", "Training has finished.\n" ] } ], "source": [ "from deepaudiox import Trainer\n", "\n", "trainer = Trainer(\n", " train_dset=train_dataset,\n", " model=model,\n", " epochs=30,\n", " patience=5,\n", " batch_size=64,\n", " path_to_checkpoint=CHECKPOINT,\n", " device=DEVICE\n", ")\n", "\n", "trainer.train() # Simply call the train method to start training!" ] }, { "cell_type": "markdown", "id": "a3b4c5d6", "metadata": {}, "source": [ "## 6. Evaluating on the Test Set\n", "\n", "Once training has finished, we reload the **best checkpoint** with `AudioClassifier.from_checkpoint` and initialize an `Evaluator` to measure the classifier's performance on the held-out test split. The test dataset was built with 3-second segments, so the metrics are computed per segment (1950 samples) rather than per file.\n", "\n", "`evaluate()` prints three things: a classification report (per-class precision, recall, and F1-score, plus overall accuracy and macro and weighted averages), the confusion matrix, and the average posteriors per class." ] }, { "cell_type": "code", "execution_count": 9, "id": "b4c5d6e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using GPU: NVIDIA GeForce RTX 4090\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Testing has finished. \n", "[REPORTER] Class mapping: {'blues': 0, 'classical': 1, 'country': 2, 'disco': 3, 'hiphop': 4, 'jazz': 5, 'metal': 6, 'pop': 7, 'reggae': 8, 'rock': 9} \n", "\n", "[REPORTER] Classification Report: \n", "\n", " precision recall f1-score support\n", "\n", " blues 0.74 0.69 0.72 200\n", " classical 0.98 0.99 0.99 200\n", " country 0.66 0.95 0.78 200\n", " disco 0.86 0.89 0.88 200\n", " hiphop 0.91 0.98 0.94 200\n", " jazz 0.86 0.94 0.90 190\n", " metal 0.87 0.88 0.88 190\n", " pop 0.87 0.93 0.90 190\n", " reggae 0.90 0.51 0.65 190\n", " rock 0.73 0.55 0.63 190\n", "\n", " accuracy 0.83 1950\n", " macro avg 0.84 0.83 0.83 1950\n", "weighted avg 0.84 0.83 0.83 1950\n", "\n", "[REPORTER] Confusion Matrix: \n", "\n", "[[139 1 27 13 0 14 0 0 0 6]\n", " [ 0 198 0 0 0 1 0 1 0 0]\n", " [ 9 0 190 0 0 0 0 0 0 1]\n", " [ 2 0 3 178 6 0 0 3 5 3]\n", " [ 0 0 0 3 196 0 0 0 0 1]\n", " [ 0 2 2 0 1 179 0 1 2 3]\n", " [ 4 0 3 1 1 0 167 0 0 14]\n", " [ 1 0 4 4 0 0 0 176 0 5]\n", " [ 30 0 12 5 10 11 0 21 96 5]\n", " [ 2 1 48 2 1 2 24 1 4 105]]\n", "[REPORTER] Average Posteriors: \n", "\n", "blues : 0.722\n", "classical : 0.958\n", "country : 0.920\n", "disco : 0.841\n", "hiphop : 0.893\n", "jazz : 0.932\n", "metal : 0.903\n", "pop : 0.855\n", "reggae : 0.767\n", "rock : 0.639\n" ] } ], "source": [ "from deepaudiox import Evaluator\n", "\n", "model = AudioClassifier.from_checkpoint(CHECKPOINT)\n", "\n", "evaluator = Evaluator(\n", " test_dset=test_dataset,\n", " model=model,\n", " class_mapping=class_mapping,\n", " device=DEVICE,\n", ")\n", "\n", "evaluator.evaluate()" ] }, { "cell_type": "markdown", "id": "c5d6e7f8", "metadata": {}, "source": [ "## 7. Inference on a New Audio File\n", "\n", "DeepAudioX models provide an `inference_on_file` method for making predictions on unseen data. It accepts any WAV or MP3 file path and returns the predicted class label along with its posterior probability.\n", "\n", "The `segment_duration` argument sets the length of the segments into which the file is split for processing. For best performance, use the same duration as during training.\n", "\n", "Here we take the first `.wav` file found in the test directory as a quick sanity check." ] }, { "cell_type": "code", "execution_count": 10, "id": "d6e7f8a9", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "sample_file = next(Path(TEST_DIR).rglob(\"*.wav\"))\n", "true_label = sample_file.parent.name\n", "\n", "result = model.inference_on_file(\n", " path=str(sample_file),\n", " sample_rate=SAMPLE_RATE,\n", " class_mapping=class_mapping,\n", " segment_duration=3.0, # same segment duration used during training\n", ")" ] }, { "cell_type": "markdown", "id": "a5d25066", "metadata": {}, "source": [ "The method returns a dictionary with four entries. `final_label` holds the file-level prediction, obtained by majority vote across all 3-second segments of the file, and `final_posterior` holds its associated posterior. `segment_labels` and `segment_posteriors` hold the segment-level predictions and their posteriors."
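, "\n", "\n", "To make the aggregation concrete, the sketch below recombines the segment-level values of `result` (displayed in the next cell) into a file-level decision. `aggregate_segments` is a hypothetical helper, not part of DeepAudioX. The majority vote follows the description above. Averaging the posteriors of the segments that voted for the winning label is an assumption, although it reproduces the reported `final_posterior` for this file.\n", "
```python
from collections import Counter
from statistics import mean


def aggregate_segments(segment_labels, segment_posteriors):
    # Majority vote over segment-level labels (ties go to the label seen first)
    final_label, _ = Counter(segment_labels).most_common(1)[0]
    # Assumed rule: average the posteriors of the segments that voted for the winner
    winning = [p for lbl, p in zip(segment_labels, segment_posteriors) if lbl == final_label]
    return final_label, mean(winning)


# Segment-level values copied from the `result` dictionary of this notebook
labels = ['country'] * 10 + ['classical']
posteriors = [0.9425491094589233, 0.9905596375465393, 0.9864566326141357,
              0.9714261889457703, 0.9702315330505371, 0.973042905330658,
              0.9922925233840942, 0.9042606353759766, 0.9638944864273071,
              0.8979018926620483, 0.6707818508148193]

label, posterior = aggregate_segments(labels, posteriors)
print(label, round(posterior, 4))  # prints: country 0.9593
```
", "\n", "Ten of the eleven segments vote `country`, so the single `classical` segment is outvoted and does not affect the file-level label."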
] }, { "cell_type": "code", "execution_count": 11, "id": "ccd78cdb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'final_label': 'country',\n", " 'final_posterior': 0.959261554479599,\n", " 'segment_labels': ['country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'country',\n", " 'classical'],\n", " 'segment_posteriors': [0.9425491094589233,\n", " 0.9905596375465393,\n", " 0.9864566326141357,\n", " 0.9714261889457703,\n", " 0.9702315330505371,\n", " 0.973042905330658,\n", " 0.9922925233840942,\n", " 0.9042606353759766,\n", " 0.9638944864273071,\n", " 0.8979018926620483,\n", " 0.6707818508148193]}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result" ] } ], "metadata": { "kernelspec": { "display_name": "deepaudio-x (3.13.9)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.9" } }, "nbformat": 4, "nbformat_minor": 5 }